# -*- coding: utf-8 -*-
from numpy import array, zeros, full, argmin, inf, ndim
from scipy.spatial.distance import cdist
from math import isinf
import numpy as np
import traceback
from matplotlib import pyplot as plt
from sklearn.preprocessing import MinMaxScaler,StandardScaler
from copy import deepcopy

def dtw(x, y, dist, warp=1, w=inf, s=1.0):
    """
    Computes Dynamic Time Warping (DTW) of two sequences.

    Every compared pair is handed to ``dist`` as two row arrays built with
    ``np.array(elem).reshape(1, -1)``, so ``dist`` may be a pairwise metric
    returning a 1x1 array (as the sklearn pairwise functions used in this
    file's demo do) or any callable returning a scalar for such inputs.

    :param array x: N1*M array
    :param array y: N2*M array
    :param func dist: distance used as cost measure
    :param int warp: how many shifts are computed.
    :param int w: window size limiting the maximal distance between indices of matched entries |i,j|.
    :param float s: weight applied on off-diagonal moves of the path. As s gets larger, the warping path is increasingly biased towards the diagonal
    Returns the minimum distance, the cost matrix, the accumulated cost matrix, and the wrap path.
    The wrap path is a pair of integer index arrays (indices into x, indices into y).
    Raises AssertionError when x or y is empty, when w is smaller than the
    length difference of the sequences, or when s is not positive.
    """
    assert len(x)
    assert len(y)
    assert isinf(w) or (w >= abs(len(x) - len(y)))
    assert s > 0
    r, c = len(x), len(y)
    windowed = not isinf(w)
    if windowed:
        # Cells outside the band stay at inf so the path can never leave it.
        D0 = full((r + 1, c + 1), inf)
        for i in range(1, r + 1):
            D0[i, max(1, i - w):min(c + 1, i + w + 1)] = 0
        D0[0, 0] = 0
    else:
        D0 = zeros((r + 1, c + 1))
        D0[0, 1:] = inf
        D0[1:, 0] = inf
    D1 = D0[1:, 1:]  # view: writing D1 fills the interior of D0

    # The row conversions do not depend on the other loop index, so build
    # them once instead of once per (i, j) pair.
    y_rows = [np.array(y[j]).reshape(1, -1) for j in range(c)]
    for i in range(r):
        x_row = np.array(x[i]).reshape(1, -1)
        for j in range(c):
            if not windowed or (max(0, i - w) <= j <= min(c, i + w)):
                # Pairwise metrics return a 1x1 array; extract the scalar
                # explicitly because assigning a size-1 array to a single
                # element is deprecated in recent NumPy releases.
                D1[i, j] = np.asarray(dist(x_row, y_rows[j])).item()
    C = D1.copy()

    jrange = range(c)
    for i in range(r):
        if windowed:
            jrange = range(max(0, i - w), min(c, i + w + 1))
        for j in jrange:
            # Candidates: the diagonal predecessor, plus up to `warp` shifted
            # horizontal/vertical predecessors weighted by s.
            min_list = [D0[i, j]]
            for k in range(1, warp + 1):
                i_k = min(i + k, r)
                j_k = min(j + k, c)
                min_list += [D0[i_k, j] * s, D0[i, j_k] * s]
            D1[i, j] += min(min_list)

    # With a single element on one side the path is fully determined.
    # Integer arrays keep the result usable as indices and consistent with
    # what _traceback returns in the general case.
    if len(x) == 1:
        path = zeros(len(y), dtype=int), np.arange(len(y))
    elif len(y) == 1:
        path = np.arange(len(x)), zeros(len(x), dtype=int)
    else:
        path = _traceback(D0)
    return D1[-1, -1], C, D1, path


def accelerated_dtw(x, y, dist, warp=1):
    """
    Computes Dynamic Time Warping (DTW) of two sequences in a faster way.
    Instead of iterating through each element and calculating each distance,
    this uses the cdist function from scipy (https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html)

    :param array x: N1*M array (a 1-D sequence is reshaped to N1*1)
    :param array y: N2*M array (a 1-D sequence is reshaped to N2*1)
    :param string or func dist: distance parameter for cdist. When string is given, cdist uses optimized functions for the distance metrics.
    Typical names are 'euclidean', 'cityblock', 'chebyshev', 'cosine', 'mahalanobis', 'sqeuclidean';
    see the cdist documentation for the names supported by the installed SciPy version.
    :param int warp: how many shifts are computed.
    Returns the minimum distance, the cost matrix, the accumulated cost matrix, and the wrap path.
    The wrap path is a pair of integer index arrays (indices into x, indices into y).
    Raises AssertionError when x or y is empty.
    """
    assert len(x)
    assert len(y)
    # cdist needs 2-D input: treat a flat sequence as a column of 1-D points.
    if ndim(x) == 1:
        x = np.array(x).reshape(-1, 1)
    if ndim(y) == 1:
        y = np.array(y).reshape(-1, 1)
    r, c = len(x), len(y)
    D0 = zeros((r + 1, c + 1))
    D0[0, 1:] = inf
    D0[1:, 0] = inf
    D1 = D0[1:, 1:]  # view: updating D1 updates the interior of D0
    D0[1:, 1:] = cdist(x, y, dist)
    C = D1.copy()
    for i in range(r):
        for j in range(c):
            # Candidates: the diagonal predecessor, plus up to `warp` shifted
            # horizontal/vertical predecessors.
            min_list = [D0[i, j]]
            for k in range(1, warp + 1):
                min_list += [D0[min(i + k, r), j],
                             D0[i, min(j + k, c)]]
            D1[i, j] += min(min_list)

    # With a single element on one side the path is fully determined.
    # Integer arrays keep the result usable as indices and consistent with
    # what _traceback returns in the general case.
    if len(x) == 1:
        path = zeros(len(y), dtype=int), np.arange(len(y))
    elif len(y) == 1:
        path = np.arange(len(x)), zeros(len(x), dtype=int)
    else:
        path = _traceback(D0)
    return D1[-1, -1], C, D1, path


def _traceback(D):
    i, j = array(D.shape) - 2
    p, q = [i], [j]
    while (i > 0) or (j > 0):
        tb = argmin((D[i, j], D[i, j + 1], D[i + 1, j]))
        if tb == 0:
            i -= 1
            j -= 1
        elif tb == 1:
            i -= 1
        else:  # (tb == 2):
            j -= 1
        p.insert(0, i)
        q.insert(0, j)
    return array(p), array(q)


def nomarize_data(idata):
    """
    Standardise the values of a ``{key: value}`` series with StandardScaler.

    :param dict idata: series to normalise; only its values are scaled.
    Returns a new dict keyed by position (0 .. n-1, in iteration order of
    ``idata``) whose values are the standardised values truncated to int.
    If anything goes wrong (e.g. empty or non-numeric values) the traceback
    is printed and ``idata`` is returned unchanged.
    """
    try:
        # Scale only the value column. StandardScaler standardises each
        # column independently, so the result for the values is the same as
        # when the keys were scaled alongside them, but non-numeric keys no
        # longer make the scaler fail and silently disable normalisation.
        column = [[idata[key]] for key in idata]
        scaled = StandardScaler().fit_transform(column)
        # NOTE(review): int() truncates the z-scores towards zero, which
        # discards most of their resolution. This looks like a leftover from
        # the earlier MinMaxScaler(0, 100) variant -- confirm it is intended.
        return {pos: int(row[0]) for pos, row in enumerate(scaled)}
    except Exception:
        # Best effort: report the problem and fall back to the raw series.
        traceback.print_exc()
        return idata

"""
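Sample of the input dict as it is read by Dtw.dtw_pre / Dtw.dtw_loop below.
The "timestamp" entries are passed through unchanged; each "value" is a dict
whose values form the series.

sample_dict = {}
sample_dict["timestamp"] = 0
sample_dict["base_item_name"] = "base"
sample_dict["base"] = {"timestamp": 0, "value": {0:0,1:1,2:2}}
sample_dict["item"] = {}
sample_dict["item"]["sample"] = {"timestamp": 0, "value": {0:0,1:1,2:2}}
sample_dict["dist"] = []  (replaced by Dtw.dtw_loop with sorted (item, distance) pairs)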
"""
class Dtw(object):
    """
    Compares a base series against a set of item series with DTW and plots
    the results.

    Data layout read by the methods:
      data["base"]["value"]        dict whose values form the base series
      data["base"]["timestamp"]    passed through to the result unchanged
      data["item"][name]["value"]  dict whose values form one item series
      data["item"][name]["timestamp"]
      data["base_item_name"]       title used for the base series

    Attributes:
      mdata  raw (un-normalised) data, set by dtw_pre; used for plotting
      pdata  data actually fed to DTW
      ret    result dict returned by dtw_loop
    """

    def __init__(self):
        self.plt_color = ["k","y","g","c","b","m","darkred","gold","skyblue","darkblue","pink"]
        self.cnum = len(self.plt_color)
        # mdata and pdata must be independent objects: sharing one dict would
        # make every write to one of them show up in the other.
        self.mdata = {"timestamp":0,"base":{},"item":{},"dist":[]}
        self.pdata = deepcopy(self.mdata)
        self.ret = {}
        self.ret["success"] = False
        # Metric name handed to scipy's cdist by accelerated_dtw.
        self.ret["cdist"] = "mahalanobis"
        self.ret["item"] = {}
        self.ret["base_item_name"] = ""

    def dtw_pre(self,idata):
        """
        Store a copy of ``idata`` and build the normalised working copy.

        The base series and every item series are passed through
        nomarize_data. Items whose values are all equal (or that have fewer
        than two values) are dropped from the working copy only; the raw copy
        in self.mdata keeps them.

        Returns the working copy (self.pdata), or ``idata`` itself if
        preprocessing raised; the traceback is printed in that case.
        """
        self.mdata = deepcopy(idata)
        self.pdata = deepcopy(idata)
        try:
            self.pdata["base"]["value"] = nomarize_data(self.pdata["base"]["value"])
            for item in list(idata["item"].keys()):
                item_values_l = list(idata["item"][item]["value"].values())
                # Shifted-slice comparison: true iff all values are identical.
                if item_values_l[1:] == item_values_l[:-1]:
                    del self.pdata["item"][item]
                else:
                    self.pdata["item"][item]["value"] = nomarize_data(self.pdata["item"][item]["value"])
            return self.pdata
        except Exception:
            traceback.print_exc()
            return idata

    def dtw_loop(self,mdata):
        """
        Compute the DTW distance between the base series and every item.

        :param dict mdata: data to analyse (typically the result of dtw_pre).
        Returns self.ret. On success ret["success"] is True, ret["dist"]
        holds (item, distance) pairs sorted by ascending distance, and
        ret["item"][name] holds "value_list", "timestamp_list" and "dist".
        If anything raises, the traceback is printed and ret["success"] is
        False; ret may then be partially filled.

        ret["base_item_name"] is read from self.mdata, so dtw_pre (or an
        equivalent assignment of self.mdata) is expected to have run first.
        Entries in ret["item"] from earlier calls are not cleared.
        """
        self.pdata = deepcopy(mdata)
        # Never report a stale success from a previous call.
        self.ret["success"] = False
        try:
            dist_t = {}
            x = list(self.pdata["base"]["value"].values())
            for item in self.pdata["item"]:
                y = list(self.pdata["item"][item]["value"].values())
                y_ts = self.pdata["item"][item]["timestamp"]
                self.ret["item"][item] = {}
                self.ret["item"][item]["value_list"] = y
                self.ret["item"][item]["timestamp_list"] = y_ts

                dist, cost, acc, path = accelerated_dtw(x, y, self.ret["cdist"], warp=1)
                self.ret["item"][item]["dist"] = dist
                dist_t[item] = dist
            self.mdata["dist"] = sorted(dist_t.items(), key=lambda x:x[1], reverse=False)
            self.pdata["dist"] = self.mdata["dist"]
            self.ret["base"] = {}
            self.ret["base"]["timestamp"] = self.pdata["base"]["timestamp"]
            self.ret["base"]["value"] = list(self.pdata["base"]["value"].values())
            self.ret["dist"] = self.mdata["dist"]
            self.ret["base_item_name"] = self.mdata["base_item_name"]
            self.ret["success"] = True
        except Exception:
            traceback.print_exc()
        return self.ret

    def dtw_plot_single(self,filepath="./single_test.png",item_t="cpu_total-user"):
        """
        Plot one item and its alignment with the base series.

        Two images are written: the raw series of ``item_t`` to ``filepath``,
        and the DTW cost matrix with the warping path to "<item_t>-path".
        Uses the raw data in self.mdata and the distance stored by dtw_loop
        in self.ret["item"][item_t]["dist"].
        """
        item_series = self.mdata["item"][item_t]["value"]
        plt.clf()
        plt.plot(range(len(item_series.keys())), list(item_series.values()), color='r', linewidth=1.5, linestyle='-', label="%s-%s"%(item_t,self.ret["item"][item_t]["dist"]))
        plt.legend(loc='upper left',fontsize=6)
        plt.savefig(filepath)

        x = list(self.mdata["base"]["value"].values())
        y = list(item_series.values())
        dist, cost, acc, path = accelerated_dtw(x, y, self.ret["cdist"], warp=1)

        plt.clf()
        plt.imshow(cost.T, origin='lower', cmap=plt.cm.Reds, interpolation='nearest')
        plt.plot(path[0], path[1], '-o')  # relation
        plt.xticks(range(len(x)), x)
        plt.yticks(range(len(y)), y)
        plt.xlabel('x')
        plt.ylabel('y')
        plt.axis('tight')
        plt.savefig("%s-path"%item_t)

    def dtw_plot(self,filepath="./dtw_plot_tmp.png",cnt=10):
        """
        Plot the base series and the closest items on a 5x4 grid.

        The first cell shows the base series; the following cells show the
        items in the order of self.mdata["dist"] (ascending distance as
        written by dtw_loop), limited to ``cnt`` items and to the 19 cells
        left on the grid. The figure is saved to ``filepath`` and closed.
        """
        num_y = 5
        num_x = 4
        plt.clf()
        fig=plt.figure(figsize=(9,6),dpi=200,facecolor='w')
        try:
            plt1 = plt.subplot(num_y,num_x,1)
            plt1.plot(range(len(self.mdata["base"]["value"].keys())), list(self.mdata["base"]["value"].values()), color='r', linewidth=1.5, linestyle='-', label='base')
            plt1.set_title(self.mdata["base_item_name"],fontsize=6)
            plt1.tick_params(labelsize=4)
            plt1.legend(loc='upper left',fontsize=4)
            num = 0
            plt_num = 1
            for i in self.mdata["dist"]:
                item = i[0]
                plt_num += 1
                plt_tmp = plt.subplot(num_y,num_x,plt_num)
                item_series = self.mdata["item"][item]["value"]
                # Cycle through the colour list when there are more items
                # than colours.
                plt_tmp.plot(range(len(item_series.keys())), list(item_series.values()),
                             '%s'%self.plt_color[num%self.cnum], label='dist=%.2f'%self.ret["item"][item]["dist"])
                plt_tmp.set_title('%s'%(item),fontsize=6)
                plt_tmp.tick_params(labelsize=4)
                num += 1
                plt_tmp.legend(loc='upper left',fontsize=4)
                if num >= cnt:
                    break
                if num >= num_y*num_x - 1:
                    break
            plt.tight_layout()
            plt.subplots_adjust(left=None, bottom=None, right=None, top=None,
                               wspace=None, hspace=1)
            plt.savefig(filepath)
        finally:
            # A new figure is created on every call; close it so repeated
            # calls do not accumulate open figures in pyplot.
            plt.close(fig)

if __name__ == '__main__':
    # Demo: run both DTW variants on a small example, print the results and
    # plot the cost matrix with the warping path of the last run.
    w = inf
    s = 1.0
    # The conditions below are hand-edited switches selecting the demo input.
    if 1:  # 1-D numeric
        from sklearn.metrics.pairwise import manhattan_distances
        x = [0, 0, 1, 1, 2, 4, 2, 1, 2, 0]
        y = [1, 1, 1, 2, 2, 2, 2, 3, 2, 0]
        dist_fun = manhattan_distances
        w = 1
    elif 0:  # 2-D numeric
        from sklearn.metrics.pairwise import euclidean_distances
        x = [[0, 0], [0, 1], [1, 1], [1, 2], [2, 2], [4, 3], [2, 3], [1, 1], [2, 2], [0, 1]]
        y = [[1, 0], [1, 1], [1, 1], [2, 1], [4, 3], [4, 3], [2, 3], [3, 1], [1, 2], [1, 0]]
        dist_fun = euclidean_distances
    else:  # 1-D list of strings
        # NOTE(review): accelerated_dtw below is called with "mahalanobis",
        # which presumably cannot handle string input -- verify before
        # enabling this branch.
        from nltk.metrics.distance import edit_distance
        x = ['i', 'soon', 'found', 'myself', 'muttering', 'to', 'the', 'walls']
        y = ['see', 'drown', 'himself']
        dist_fun = edit_distance
    dist, cost, acc, path = dtw(x, y, dist_fun, w=w, s=s)
    print (dist,"\n", cost,"\n", acc,"\n", path )
    dist, cost, acc, path = accelerated_dtw(x, y, "mahalanobis", warp=1)
    print (dist,"\n", cost,"\n", acc,"\n", path )

    # Visualize (the values plotted and shown in the title are those of the
    # accelerated_dtw run above, which overwrote the dtw results).
    plt.imshow(cost.T, origin='lower', cmap=plt.cm.Reds, interpolation='nearest')
    plt.plot(path[0], path[1], '-o')  # relation
    plt.xticks(range(len(x)), x)
    plt.yticks(range(len(y)), y)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.axis('tight')
    if isinf(w):
        plt.title('Minimum distance: {}, slope weight: {}'.format(dist, s))
    else:
        plt.title('Minimum distance: {}, window widht: {}, slope weight: {}'.format(dist, w, s))
    # Save before showing: once the window opened by show() has been closed
    # the figure is gone and savefig would write an empty image.
    plt.savefig('test1.png')
    plt.show()
